import os
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, random_split
from torchmetrics import Accuracy
from torchvision import transforms
from torchvision.datasets import MNIST
import matplotlib.pyplot as plt
# Dataset root directory handed to MNIST (empty string here).
PATH_DATASETS = ""
# Number of samples per evaluation batch.
BATCH_SIZE = 1024
# Prefer the GPU when CUDA is available; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)



# Load the MNIST test split (train=False) from PATH_DATASETS, converting each
# image with ToTensor. download=False: the files are expected to already exist
# under PATH_DATASETS rather than being fetched here.
test_ds = MNIST(
    root=PATH_DATASETS,
    train=False,
    download=False,
    transform=transforms.ToTensor(),
)

# Loss used to score the model's outputs during evaluation.
criterion = nn.CrossEntropyLoss()


# Load the trained model. map_location remaps the stored tensors onto the
# selected device, so a checkpoint saved on a GPU can be evaluated on a
# CPU-only machine (and vice versa).
# NOTE(review): torch.load unpickles arbitrary Python objects -- only load
# trusted checkpoint files.
# NOTE(review): newer PyTorch releases default torch.load to weights_only=True,
# which may reject a pickled full nn.Module; confirm against the installed
# version and pass weights_only=False there if needed (trusted files only).
model = torch.load('model.pt', map_location=device)
# Keep the parameters on the same device as the input batches moved below.
model = model.to(device)
test_loader = DataLoader(test_ds, shuffle=False, batch_size=BATCH_SIZE)
model.eval()  # inference behaviour for layers such as dropout / batch norm

test_loss = 0.0
correct = 0
with torch.no_grad():  # evaluation only: no autograd graph needed
    for data, target in test_loader:
        data, target = data.to(device), target.to(device)
        output = model(data)
        # criterion is nn.CrossEntropyLoss() with its default mean reduction,
        # so it returns the mean loss of this batch. Weight it by the batch
        # size to accumulate a per-sample sum (batches can differ in size,
        # e.g. the final partial batch); it is divided by the sample count below.
        test_loss += criterion(output, target).item() * data.size(0)
        pred = output.argmax(1, keepdim=True)  # highest-scoring index along dim 1
        correct += pred.eq(target.view_as(pred)).sum().item()

data_count = len(test_loader.dataset)
test_loss /= data_count  # mean loss per sample
percentage = 100. * correct / data_count
# Report average loss (平均损失) and accuracy (准确率).
print(f'平均损失: {test_loss:.4f}, 准确率: {correct}/{data_count}' +
      f' ({percentage:.0f}%)\n')